package com.gitee.zhuyb.interceptor;

import com.gitee.zhuyb.domain.PageUtil;
import com.gitee.zhuyb.domain.PageVo;
import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.PreparedStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Custom MyBatis pagination interceptor.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection)}. When
 * {@code PageUtil.getPageVo()} returns a non-null {@code PageVo}, the interceptor first
 * runs a count query and stores the total record count and page count on that object,
 * then replaces the bound SQL with a dialect-specific paged query before the statement
 * is prepared.
 *
 * <p>The dialect comes from the plugin property {@code dialect}. The values
 * {@code mysql} and {@code oracle} are recognised case-insensitively; any other value,
 * or a missing property, leaves the SQL text unchanged (the count query still runs).
 *
 * @version 1.0.0
 * @Date: 2021/10/30 13:28
 * @Copyright (C) ZhuYouBin
 */
@Intercepts({@Signature(type=StatementHandler.class, method="prepare", args=Connection.class)})
public class PageInterceptor implements Interceptor {

    /** Dialect name selecting MySQL {@code LIMIT} paging. */
    private static final String DIALECT_MYSQL = "mysql";

    /** Dialect name selecting Oracle {@code ROWNUM} paging. */
    private static final String DIALECT_ORACLE = "oracle";

    /**
     * Misspelled Oracle dialect name that earlier versions of this class compared against;
     * still accepted so configurations written for that version keep working.
     */
    private static final String DIALECT_ORACLE_LEGACY = "oralce";

    /** Matches the SQL keyword FROM as a whole word, in any letter case. */
    private static final Pattern FROM_KEYWORD = Pattern.compile("\\bfrom\\b", Pattern.CASE_INSENSITIVE);

    /** Database dialect, read from the {@code dialect} plugin property; may be null. */
    private String dialect;

    /**
     * Applies pagination to the statement being prepared when a {@code PageVo} is active.
     *
     * @param invocation the intercepted {@code prepare} call; its first argument is the
     *                   JDBC connection
     * @return the result of the intercepted call
     * @throws Throwable any failure of the count query or of the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Fetch the paging parameters; null means paging was not requested.
        PageVo pageVo = PageUtil.getPageVo();
        if (pageVo != null) {
            // The reflective lookup and cast are done only when paging is requested, so
            // statements that are not paged are passed through untouched.
            StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
            MetaObject handlerMeta = MetaObject.forObject(statementHandler);
            PreparedStatementHandler preparedStatementHandler =
                    (PreparedStatementHandler) handlerMeta.getValue("delegate");

            BoundSql boundSql = preparedStatementHandler.getBoundSql();
            Object parameterObject = boundSql.getParameterObject();
            Connection connection = (Connection) invocation.getArgs()[0];

            // Build the paged SQL from the original text.
            String pageSql = this.getPageSql(boundSql.getSql(), pageVo);
            // Compute the totals; this still reads the original (unpaged) SQL.
            this.countTotal(pageVo, parameterObject, preparedStatementHandler, connection);
            // Swap the paged SQL into the BoundSql so the real statement uses it.
            MetaObject boundSqlMeta = MetaObject.forObject(boundSql);
            boundSqlMeta.setValue("sql", pageSql);
        }
        // Continue with the intercepted call.
        return invocation.proceed();
    }

    /**
     * Wraps the target in a plugin proxy for this interceptor.
     *
     * @param target the object MyBatis offers for interception
     * @return the value returned by {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the plugin configuration.
     *
     * @param properties plugin properties; the {@code dialect} entry selects the paging syntax
     */
    @Override
    public void setProperties(Properties properties) {
        this.dialect = properties.getProperty("dialect");
    }

    // ------------------------------------------------------------------ helpers

    /**
     * Runs the count query and stores the total record count and page count on {@code pageVo}.
     *
     * <p>The count statement reuses the parameter mappings and parameter object of the
     * original statement. Additional parameters held by the original {@code BoundSql}
     * are not copied to the count {@code BoundSql}.
     *
     * @param pageVo           paging object that receives the totals
     * @param parameterObject  parameter object of the original statement
     * @param statementHandler handler providing the original {@code BoundSql} and mapped statement
     * @param connection       connection on which the count query is executed
     * @throws SQLException if preparing, binding or executing the count query fails
     */
    private void countTotal(PageVo pageVo, Object parameterObject, PreparedStatementHandler statementHandler,
                            Connection connection) throws SQLException {
        MetaObject handlerMeta = MetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) handlerMeta.getValue("mappedStatement");
        BoundSql boundSql = statementHandler.getBoundSql();
        // Derive the count SQL from the original SQL.
        String countSql = this.getCountSql(boundSql.getSql());
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
                boundSql.getParameterMappings(), parameterObject);
        ParameterHandler parameterHandler =
                new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);

        // try-with-resources closes both resources even when one close() fails; errors
        // propagate so a failed count never leaves stale totals on pageVo.
        try (PreparedStatement pstmt = connection.prepareStatement(countSql)) {
            parameterHandler.setParameters(pstmt);
            try (ResultSet rs = pstmt.executeQuery()) {
                if (rs.next()) {
                    int totalRecord = rs.getInt(1);
                    pageVo.setTotal(totalRecord); // total number of records
                    pageVo.setPages((totalRecord - 1) / pageVo.getPageSize() + 1); // total number of pages
                }
            }
        }
    }

    /**
     * Builds the paged SQL for the configured dialect.
     *
     * @param sql    SQL text before interception
     * @param pageVo paging parameters
     * @return the paged SQL, or {@code sql} unchanged when the dialect is missing or not recognised
     */
    private String getPageSql(String sql, PageVo pageVo) {
        // Constant-first comparisons are null-safe when no dialect was configured.
        if (DIALECT_MYSQL.equalsIgnoreCase(dialect)) {
            return this.getMysqlPageSql(new StringBuilder(sql), pageVo);
        }
        if (DIALECT_ORACLE.equalsIgnoreCase(dialect) || DIALECT_ORACLE_LEGACY.equalsIgnoreCase(dialect)) {
            return this.getOraclePageSql(new StringBuilder(sql), pageVo);
        }
        return sql;
    }

    /**
     * Builds the SQL that counts the rows of the original query.
     *
     * <p>Everything before the first whole-word {@code FROM} (matched in any letter case)
     * is replaced by {@code select count(1)}. If no {@code FROM} keyword is found, the
     * whole statement is wrapped as a subquery instead.
     *
     * @param sql SQL text before interception
     * @return the count SQL
     */
    private String getCountSql(String sql) {
        Matcher fromMatcher = FROM_KEYWORD.matcher(sql);
        if (!fromMatcher.find()) {
            return "select count(1) from (" + sql + ") page_count_tmp";
        }
        return "select count(1) " + sql.substring(fromMatcher.start());
    }

    /**
     * Appends a MySQL {@code LIMIT} clause to the query.
     *
     * @param sql    SQL text before interception; modified in place
     * @param pageVo paging parameters
     * @return the paged SQL
     */
    private String getMysqlPageSql(StringBuilder sql, PageVo pageVo) {
        // NOTE(review): pageIndex is used directly as the LIMIT offset here, whereas the
        // Oracle variant treats it as a 1-based page number ((page - 1) * size). Confirm
        // which meaning PageVo.getPageIndex() has and align the two.
        sql.append(" limit ")
                .append(pageVo.getPageIndex())
                .append(",")
                .append(pageVo.getPageSize());
        return sql.toString();
    }

    /**
     * Wraps the query in nested {@code ROWNUM} subqueries for Oracle paging.
     *
     * @param sql    SQL text before interception; modified in place
     * @param pageVo paging parameters; {@code pageIndex} is used as a 1-based page number
     * @return the paged SQL
     */
    private String getOraclePageSql(StringBuilder sql, PageVo pageVo) {
        int page = pageVo.getPageIndex();
        int size = pageVo.getPageSize();
        // Exclusive lower bound and inclusive upper bound of the row numbers to return.
        int startIndex = (page - 1) * size;
        int endIndex = page * size;
        // Inner wrapper: number the rows and cut off everything after the page.
        sql.insert(0, "select u.*, rownum r from (")
                .append(") u where rownum <= ")
                .append(endIndex);
        // Outer wrapper: drop the rows that precede the page.
        sql.insert(0, "select * from (")
                .append(") where r > ")
                .append(startIndex);
        return sql.toString();
    }
}
